A gentle introduction to the SumProductTransform library
This introduction uses several unregistered libraries, namely ToyProblems.jl and SumProductTransform.jl, which depends on Unitary.jl. The easiest way to get them is to instantiate the environment in the example/ directory, which contains everything you need, including Pluto.
The introduction starts with a classic Gaussian Mixture Model, continues with a simple Sum Product Network, and finishes with a Sum Product Transform Network.
Before we dive into the real work, we import the libraries and define a convenient function for plotting densities and data.
x
md"""# A gentle introduction to the SumProductTransform library

This introduction uses several unregistered libraries, namely `ToyProblems.jl` and `SumProductTransform.jl`, which depends on `Unitary.jl`. The easiest way to get them is to `instantiate` the environment in the `example/` directory, which contains everything you need, including Pluto.

The introduction starts with a classic **Gaussian Mixture Model**, continues with a simple **Sum Product Network**, and finishes with a **Sum Product Transform Network**.

Before we dive into the real work, we import the libraries and define a convenient function for plotting densities and data."""
begin using Pkg Pkg.activate(".") using ToyProblems, Distributions, SumProductTransform, Unitary, Flux, Setfield using SumProductTransform: fit!,logpdf using ToyProblems: flower2 using Unitary: ScaleShift, SVDDense using Plots plotly()end;A plotting function will show the density of a fitted model and that of with training data on top
x
md"""A plotting function shows the density of a fitted model, once on its own and once with the training data plotted on top."""
"""
    plot_contour(m, x, title = nothing)

Plot the density of model `m` in two panels: the contour plot alone, and the
same contour plot with the training data `x` scattered on top.

`x` is read as `x[1, :]` / `x[2, :]`, i.e. the two coordinates are the rows and
each column is one sample. Contour levels are quantiles of the model density at
the training points, extended by ten evenly spaced levels between 0 and the
lowest quantile so that the low-density tails are visible as well.

# Arguments
- `m`: model accepted by `logpdf(m, x)` (matrix) and `logpdf(m, [a, b])` (point).
- `x`: data matrix with (at least) two rows.
- `title`: optional title applied to the panels; `nothing` leaves them untitled.

# Returns
A two-panel `Plots` plot.
"""
function plot_contour(m, x, title = nothing)
    levels = quantile(exp.(logpdf(m, x)), 0.01:0.09:0.99)
    # Prepend ten levels in [0, lowest quantile). Skip this when the lowest
    # quantile underflows to zero: a zero step would make the range throw.
    δ = levels[1] / 10
    if δ > 0
        levels = vcat(collect(levels[1] - 10δ:δ:levels[1] - δ), levels)
    end

    xr = range(minimum(x[1, :]) - 1, maximum(x[1, :]) + 1, length = 200)
    yr = range(minimum(x[2, :]) - 1, maximum(x[2, :]) + 1, length = 200)
    modeldensity(a, b) = exp(logpdf(m, [a, b])[1])

    # The computed levels were previously discarded; pass them to the contour.
    p1 = Plots.contour(xr, yr, modeldensity; levels = levels)
    # scatter! mutates its target, so draw the data on a copy of the contour panel.
    p2 = deepcopy(p1)
    scatter!(p2, x[1, :], x[2, :], alpha = 0.4)
    return isnothing(title) ? plot(p1, p2) : plot(p1, p2; title = title)
end;
x
md""" Let's create training samples from the **Flower** dataset with nine petals."""
# Training data: 999 samples from the nine-petal Flower toy problem.
x = flower2(999; npetals = 9);

md""" Initialize dimension of data `d`, batchsize in stochastic gradient descend, and number of training steps"""

# Settings shared by every model fitted below.
begin
    d = size(x, 1)      # number of rows of `x`, i.e. the data dimension
    batchsize = 100     # minibatch size passed to `fit!`
    nsteps = 20_000     # number of training steps passed to `fit!`
end;
gmm with 144 components
x
# The heading and the sentence must be on separate lines; otherwise Markdown
# renders the whole text as a single level-3 heading.
md"""### Gaussian Mixture Model
`gmm` with 144 components"""

# Gaussian mixture: a SumNode over `ngmm_components` components, each one a
# zero-mean, unit-scale `MvNormal(d, 1f0)` wrapped in its own `SVDDense(d)`
# transformation (presumably a learnable dense map from Unitary.jl — confirm).
begin
    ngmm_components = 144
    init_normal(d) = TransformationNode(SVDDense(d), MvNormal(d, 1f0))
    gmm_components = [init_normal(d) for _ in 1:ngmm_components]
    gmm = SumNode(gmm_components)
end;

fit!(gmm, x, batchsize, nsteps);

plot_contour(gmm, x)
md"""### Sum Product network"""

# Sum Product Network: a SumNode over `spt_ncomponents` ProductNodes. Each
# ProductNode combines two independent SumNodes of `spt_ncomponents` 1-D
# `Normal1D` leaves (a `MvNormal(1, 1f0)` behind a `ScaleShift(1)`).
begin
    spt_ncomponents = 9
    Normal1D() = TransformationNode(ScaleShift(1), MvNormal(1, 1f0))
    spn_components = [
        ProductNode((
            SumNode([Normal1D() for _ in 1:spt_ncomponents]),
            SumNode([Normal1D() for _ in 1:spt_ncomponents]),
        ))
        for _ in 1:spt_ncomponents
    ]
    spn = SumNode(spn_components)
end;

fit!(spn, x, batchsize, nsteps);

plot_contour(spn, x)
with affine transformations and Normal distributions on the leaves
# The heading and the sentence must be on separate lines; otherwise Markdown
# renders "networkwith ..." as one heading.
md"""### Sum Product Transform network
with affine transformations and Normal distribution on leaves"""

# Sum Product Transform network built from the leaf upwards. The leaf is a
# `d`-dimensional `MvNormal(d, 1f0)` behind a `ScaleShift(d)`. Each of the
# three iterations wraps the current network in `nsptn_components`
# `TransformationNode(SVDDense(d), ·)`s and sums them with a SumNode.
# NOTE(review): the comprehension wraps the *same* `sptn` object every time, so
# the children of each SumNode share one subtree (presumably intentional
# parameter sharing — confirm).
begin
    nsptn_components = 3
    global sptn = TransformationNode(ScaleShift(d), MvNormal(d, 1f0))
    for _ in 1:3
        global sptn
        sptn = SumNode([TransformationNode(SVDDense(d), sptn) for _ in 1:nsptn_components])
    end
end;

fit!(sptn, x, batchsize, nsteps);

plot_contour(sptn, x)
with nonlinear transformations on the leaves
x
# The heading and the sentence must be on separate lines; otherwise Markdown
# renders the whole text as a single level-3 heading.
md"""### Sum Product Transform network
with nonlinear transformation on leaves"""

# Sum Product Transform network with nonlinear leaves. The leaf is a
# ProductNode of `d` independent 1-D mixtures. Each mixture is a SumNode of
# three `MvNormal(1, 1f0)`s, each pushed through its own
# `Chain(SVDDense(1, selu), ScaleShift(1))`. Three SumNode layers of
# `SVDDense(d)` transformations are then stacked on top, as in `sptn`.
begin
    # One 1-D mixture of three nonlinearly transformed normals.
    nonlinear_mixture1d() = SumNode([
        TransformationNode(Chain(SVDDense(1, selu), ScaleShift(1)), MvNormal(1, 1f0))
        for _ in 1:3
    ])
    # One 1-D factor per data dimension (two factors for the 2-D flower data).
    leaf = ProductNode(ntuple(_ -> nonlinear_mixture1d(), d))
    global sptn2 = leaf
    for _ in 1:3
        global sptn2
        sptn2 = SumNode([TransformationNode(SVDDense(d), sptn2) for _ in 1:3])
    end
end;

fit!(sptn2, x, batchsize, nsteps);

plot_contour(sptn2, x)